# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========

# Enables postponed evaluation of annotations (for string-based type hints)
from __future__ import annotations

import inspect
import time
from typing import Any, ClassVar, Dict
from uuid import UUID, uuid4

from pydantic import BaseModel, ConfigDict, Field

from camel.messages import BaseMessage, FunctionCallingMessage, OpenAIMessage
from camel.types import OpenAIBackendRole


class MemoryRecord(BaseModel):
    r"""The basic message storing unit in the CAMEL memory system.

    Attributes:
        message (BaseMessage): The main content of the record.
        role_at_backend (OpenAIBackendRole): An enumeration value representing
            the role this message played at the OpenAI backend. Note that this
            value is different from the :obj:`RoleType` used in the CAMEL role
            playing system.
        uuid (UUID, optional): A universally unique identifier for this record.
            This is used to uniquely identify this record in the memory system.
            If not given, it will be assigned with a random UUID.
        extra_info (Dict[str, str], optional): A dictionary of additional
            key-value pairs that provide more information. If not given, it
            will be an empty `Dict`.
        timestamp (float, optional): The time the record was created, in
            seconds since the epoch (computed from :func:`time.time_ns`).
            Defaults to the moment the record is instantiated.
        agent_id (str, optional): The identifier of the agent associated with
            this memory. Defaults to an empty string.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    message: BaseMessage
    role_at_backend: OpenAIBackendRole
    uuid: UUID = Field(default_factory=uuid4)
    extra_info: Dict[str, str] = Field(default_factory=dict)
    # Seconds since the epoch. Derived from a nanosecond clock, but stored
    # as a float, so only roughly sub-microsecond resolution survives.
    timestamp: float = Field(
        default_factory=lambda: time.time_ns() / 1_000_000_000
    )
    agent_id: str = Field(default="")

    # Maps the "__class__" tag written by `to_dict` back to a message class.
    _MESSAGE_TYPES: ClassVar[dict] = {
        "BaseMessage": BaseMessage,
        "FunctionCallingMessage": FunctionCallingMessage,
    }

    # Constructor parameter names per message class (performance cache).
    # Keyed by the class object itself: class names are not unique across
    # modules, so keying by `__name__` could mix up unrelated classes.
    _constructor_params_cache: ClassVar[Dict[type, set]] = {}

    @classmethod
    def _get_constructor_params(cls, message_cls) -> set:
        r"""Return the ``__init__`` parameter names of ``message_cls``.

        The result (excluding ``self``) is computed once per class and then
        served from :attr:`_constructor_params_cache`.

        Args:
            message_cls (type): The message class to inspect.

        Returns:
            set: The names of the constructor parameters.
        """
        params = cls._constructor_params_cache.get(message_cls)
        if params is None:
            sig = inspect.signature(message_cls.__init__)
            params = set(sig.parameters) - {'self'}
            cls._constructor_params_cache[message_cls] = params
        return params

    @staticmethod
    def _deserialize_image_list(image_items: Any) -> list:
        r"""Convert serialized ``image_list`` entries back into images/URLs.

        Each entry is either a dict carrying a ``"type"`` key, or (legacy
        format) a bare base64 string. Dict entries of type ``"url"`` are
        returned as their ``"data"`` string unchanged; any other dict entry
        is decoded from base64 into a PIL image, restoring a saved
        ``"format"`` when present. Legacy strings are decoded from base64.

        Args:
            image_items (Iterable[Any]): The serialized image entries.

        Returns:
            list: PIL images and/or URL strings, in the input order.
        """
        import base64
        from io import BytesIO

        # PIL is only needed when images are present, so import lazily.
        from PIL import Image

        def _decode(encoded):
            return Image.open(BytesIO(base64.b64decode(encoded)))

        images = []
        for item in image_items:
            if not isinstance(item, dict):
                # Legacy format: assume it's a base64 string.
                images.append(_decode(item))
            elif item["type"] == "url":
                # URL string, keep as-is.
                images.append(item["data"])
            else:  # type == "base64"
                img = _decode(item["data"])
                # Restore the format attribute if it was saved.
                if "format" in item:
                    img.format = item["format"]
                images.append(img)
        return images

    @classmethod
    def from_dict(cls, record_dict: Dict[str, Any]) -> "MemoryRecord":
        r"""Reconstruct a :obj:`MemoryRecord` from the input dict.

        Args:
            record_dict (Dict[str, Any]): A dict generated by :meth:`to_dict`.
                The ``"extra_info"``, ``"timestamp"`` and ``"agent_id"`` keys
                are optional: when absent, the field defaults are used, so
                dicts written before those fields existed still load.

        Returns:
            MemoryRecord: The reconstructed record.

        Raises:
            KeyError: If ``record_dict`` lacks ``"uuid"``, ``"message"`` or
                ``"role_at_backend"``, if the message lacks ``"__class__"``,
                or if ``"__class__"`` is not a key of :attr:`_MESSAGE_TYPES`.
        """
        import base64

        from camel.types import RoleType

        message_data = record_dict["message"]
        message_cls = cls._MESSAGE_TYPES[message_data["__class__"]]
        # Copy without the type tag; the tag is not a constructor argument.
        data = {k: v for k, v in message_data.items() if k != "__class__"}

        # Convert role_type string to enum.
        if isinstance(data.get("role_type"), str):
            data["role_type"] = RoleType(data["role_type"])

        # Deserialize image_list from base64 strings/URLs back to PIL
        # Images/URLs.
        if data.get("image_list") is not None:
            data["image_list"] = cls._deserialize_image_list(
                data["image_list"]
            )

        # Deserialize video_bytes from base64 string.
        if data.get("video_bytes") is not None:
            data["video_bytes"] = base64.b64decode(data["video_bytes"])

        # Separate constructor args from extra fields.
        valid_params = cls._get_constructor_params(message_cls)
        kwargs = {k: v for k, v in data.items() if k in valid_params}
        extra_fields = {k: v for k, v in data.items() if k not in valid_params}

        # Merge unknown fields into meta_dict (extra fields take precedence);
        # an empty/missing meta_dict with no extras is normalized to None.
        existing_meta = kwargs.get("meta_dict") or {}
        if extra_fields:
            kwargs["meta_dict"] = {**existing_meta, **extra_fields}
        elif not existing_meta:
            kwargs["meta_dict"] = None

        # Accept both the serialized string form and an enum instance.
        role_at_backend = record_dict["role_at_backend"]
        if isinstance(role_at_backend, str):
            role_at_backend = OpenAIBackendRole(role_at_backend)

        # Accept both the serialized string form and a UUID instance.
        record_uuid = record_dict["uuid"]
        if not isinstance(record_uuid, UUID):
            record_uuid = UUID(record_uuid)

        # Forward optional fields only when present so that the field
        # defaults apply to older dicts that lack them.
        optional_fields = {
            key: record_dict[key]
            for key in ("extra_info", "timestamp", "agent_id")
            if key in record_dict
        }

        return cls(
            uuid=record_uuid,
            message=message_cls(**kwargs),
            role_at_backend=role_at_backend,
            **optional_fields,
        )

    def to_dict(self) -> Dict[str, Any]:
        r"""Convert the :obj:`MemoryRecord` to a dict for serialization
        purposes.

        Returns:
            Dict[str, Any]: A dict that :meth:`from_dict` can consume. The
                message is tagged with its class name under ``"__class__"``.
        """
        return {
            "uuid": str(self.uuid),
            "message": {
                "__class__": self.message.__class__.__name__,
                **self.message.to_dict(),
            },
            # Store the enum's value when it has one; otherwise store as-is.
            "role_at_backend": getattr(
                self.role_at_backend, "value", self.role_at_backend
            ),
            "extra_info": self.extra_info,
            "timestamp": self.timestamp,
            "agent_id": self.agent_id,
        }

    def to_openai_message(self) -> OpenAIMessage:
        r"""Converts the record to an :obj:`OpenAIMessage` object."""
        return self.message.to_openai_message(self.role_at_backend)


class ContextRecord(BaseModel):
    r"""A single hit returned by a memory retrieval.

    Attributes:
        memory_record (MemoryRecord): The record that was retrieved.
        score (float): The score the retrieval attached to this record.
        timestamp (float, optional): When this result was created, in seconds
            since the epoch (computed from :func:`time.time_ns`). Defaults to
            the moment the instance is built.
    """

    memory_record: MemoryRecord
    score: float
    # Float seconds: the nanosecond clock reading is divided down, and a
    # float keeps only roughly sub-microsecond resolution at epoch scale.
    timestamp: float = Field(
        default_factory=lambda: time.time_ns() / 1_000_000_000,
    )
